#!/bin/bash

# Accelerate训练启动脚本

# 检查CUDA是否可用
if ! command -v nvidia-smi &> /dev/null
then
    echo "CUDA不可用，使用CPU训练"
    python scripts/train.py --config configs/accelerate_config.json
else
    # 获取GPU数量
    NUM_GPUS=$(nvidia-smi --query-gpu=count --format=csv,noheader,nounits | head -1)
    echo "检测到 $NUM_GPUS 个GPU"
    
    if [ $NUM_GPUS -gt 1 ]; then
        # 多GPU训练
        echo "启动Accelerate多GPU训练"
        accelerate launch --multi_gpu --num_processes=$NUM_GPUS scripts/train.py --config configs/accelerate_config.json
    else
        # 单GPU训练
        echo "启动单GPU训练"
        python scripts/train.py --config configs/accelerate_config.json
    fi
fi